import numpy as np, matplotlib.pyplot as plt, seaborn as sns, numpy as np, sys, os, tqdm, lmfit as lm
import pandas as pd, scipy.stats as stats, scipy.optimize as op, scipy.interpolate as ip, pathlib
# Silence pandas' SettingWithCopyWarning for chained assignment.
pd.options.mode.chained_assignment = None

# Per-grain-kind plot styling, keyed by the two-letter grain code.
# Legend text for each grain kind.
label = dict(
    gb='Spheres',
    ng='Natural gravel',
    gg='Rounded chips',
    ls='Rectangular prisms',
    oc='Faceted ellipsoids',
)
# Marker/line colour ('bm' has a colour but no label, marker or size entry).
color = dict(
    gb=sns.xkcd_rgb['cobalt blue'],
    gg=sns.xkcd_rgb['cerulean'],
    oc=sns.xkcd_rgb['blue green'],
    ng=sns.xkcd_rgb['maize'],
    ls=sns.xkcd_rgb['blood orange'],
    bm='y',
)
# Matplotlib marker symbol.
marker = dict(gb='o', gg='D', oc='*', ng='v', ls='s')
# Marker size in points.
size = dict(gb=9, gg=9, oc=18, ng=12, ls=9)

class Grain:
    """Transport data, derived nondimensional quantities and power-law fits for one grain kind.

    On construction the raw properties and flume runs are stored, and the
    Shields number ``tau8``, the flux ``q8`` and their drag- (``_CDx``) and
    friction- (``_mus``) corrected variants are computed together with
    propagated uncertainties (the ``d``-prefixed attributes).

    The ``fit_model_*`` methods fit ``q* = c * (x - xc)**(3/2)`` against one of
    the Shields-number variants. The ``plot_fit_*`` methods fit and then draw the
    result. Plot styling is read from the module-level ``color``, ``marker``,
    ``size`` and ``label`` tables, keyed by grain name.
    """

    # Mean bed angle (degrees) subtracted in the friction normalization mu*.
    bed_angle = 3.5
    # Side-wall correction factor applied to the measured bed shear stress.
    sidewall_factor = 2.41

    def __init__(self, name, ARS, Meas_SV, VES_SV, density, shape, df, y_data_lim, fixed_params):
        """Store the inputs for grain kind ``name`` and compute derived quantities.

        Parameters
        ----------
        name : str
            Grain-kind key used to look up plot styling (e.g. ``'gb'``).
        ARS : row-like
            Provides ``mean_angle`` and ``std_angle``, in degrees.
        Meas_SV : row-like
            Provides ``CD``, ``vel_mean`` and ``dvel_mean``.
        VES_SV : row-like
            Provides ``CD_VES``, ``ws_sphere`` and ``dws_sphere``.
        density : row-like
            Provides ``total_grain_ro`` and ``dtotal_grain_ro``.
        shape : row-like
            Provides ``dn``, ``ddn``, ``CSF``, ``dCSF`` and axes ``A``, ``B``, ``C``.
        df : pandas.DataFrame
            One row per run, with columns ``flow_depth``, ``fluid_density``,
            ``dfluid_density``, ``taub``, ``dtaub``, ``mass_flux`` and ``dmass_flux``.
        y_data_lim : float
            Stored unchanged as ``self.y_data_lim``.
        fixed_params : sequence
            The first two items are stored as ``tau_fix`` and ``c_fix``.
        """
        self.name = name
        self.tau_fix = fixed_params[0]
        self.c_fix = fixed_params[1]
        self.y_data_lim = y_data_lim
        self.color = color[self.name]

        # Angle statistics in degrees; ar_glass is the reference angle at which mu* == 1.
        self.ar = ARS.mean_angle
        self.ar_glass = 24.
        self.dar = ARS.std_angle

        self.df_clean = df.copy()
        self.d = self.df_clean.flow_depth.values  # mean water depth

        self.dn = shape.dn
        self.D = shape.dn
        self.dD = shape.ddn
        self.CSF = shape.CSF
        self.dCSF = shape.dCSF
        self.E = shape.C * np.sqrt(3 / (shape.A**2 + shape.B**2 + shape.C**2))
        self.asprat = shape.C / shape.A

        self.CD = Meas_SV.CD
        self.CD_VES = VES_SV.CD_VES
        self.wves = VES_SV.ws_sphere
        self.dwves = VES_SV.dws_sphere
        self.wmeas = Meas_SV.vel_mean
        self.dwmeas = Meas_SV.dvel_mean

        self.ros = density.total_grain_ro
        self.dros = density.dtotal_grain_ro
        # Fluid density and its uncertainty are taken from the first run only.
        self.rof = self.df_clean.fluid_density.values[0]
        self.drof = self.df_clean.dfluid_density.values[0]

        self.g = 9.8
        self.b = 0.0105  # m
        self.db = .001  # error in width measurement

        self._compute_derived()

    def _compute_derived(self):
        """Compute mu*, normalizations, tau*, q* and their propagated uncertainties.

        Reads only attributes already set on ``self`` (angles, densities,
        diameter, velocities, CSF, g, b/db and the ``taub``/``dtaub``/
        ``mass_flux``/``dmass_flux`` columns of ``df_clean``).
        """
        def tand(deg):
            return np.tan(np.radians(deg))

        denom = tand(self.ar_glass) - tand(self.bed_angle)
        self.mustar = (tand(self.ar) - tand(self.bed_angle)) / denom
        # d tan(a)/da = 1/cos^2(a), with the angle error converted to radians.
        self.dmustar = np.abs(np.radians(self.dar) / (np.cos(np.radians(self.ar))**2 * denom))

        # Normalizing constant for sediment flux: sqrt(g D R) * D with R = ros/rof - 1.
        self.R = (self.ros / self.rof - 1)
        self.qs_norm = np.sqrt(self.g * self.D * self.R) * self.D
        dR = np.hypot(self.dros / self.rof, self.ros * self.drof / self.rof**2)
        # qs_norm ~ D**1.5 * R**0.5, so the relative errors carry weights 1.5 and 0.5.
        self.dqs_norm = self.qs_norm * np.hypot(1.5 * self.dD / self.D, 0.5 * dR / self.R)

        # Normalizing constant for shear stress: (ros - rof) g D.
        drho = self.ros - self.rof
        self.tau_norm = drho * self.g * self.D
        # Absolute errors of a difference add in quadrature; they are not relative errors.
        self.dtau_norm = self.tau_norm * np.hypot(np.hypot(self.dros, self.drof) / drho, self.dD / self.D)

        self.taub = self.sidewall_factor * self.df_clean.taub.values  # apply side wall correction factor
        # NOTE(review): dtaub is not multiplied by the side-wall factor while taub is;
        # confirm whether the stored uncertainty already includes the correction.
        self.dtaub = self.df_clean.dtaub.values
        rel_taub = self.dtaub / self.taub
        rel_tau_norm = self.dtau_norm / self.tau_norm

        self.tau8 = self.taub / self.tau_norm
        self.dtau8 = self.tau8 * np.sqrt(rel_taub**2 + rel_tau_norm**2)

        self.CDx = (self.wves / self.wmeas)**2 * self.CSF
        self.dCDx = 2 * self.CDx * np.sqrt((self.dwmeas / self.wmeas)**2 + (self.dwves / self.wves)**2 + ((1/2) * self.dCSF / self.CSF)**2)
        self.tau8_CDx = self.CDx * self.taub / self.tau_norm
        self.dtau8_CDx = self.tau8_CDx * np.sqrt((self.dCDx / self.CDx)**2 + rel_taub**2 + rel_tau_norm**2)

        self.tau8_mus = (1 / self.mustar) * self.taub / self.tau_norm
        self.dtau8_mus = self.tau8_mus * np.sqrt((self.dmustar / self.mustar)**2 + rel_taub**2 + rel_tau_norm**2)

        mass_flux = self.df_clean.mass_flux.values
        dmass_flux = self.df_clean.dmass_flux.values
        self.qs = self.df_clean.mass_flux / (self.ros * self.b)
        self.dqs = self.qs * np.sqrt((dmass_flux / mass_flux)**2 + (self.dros / self.ros)**2 + (self.db / self.b)**2)
        self.q8 = self.qs / self.qs_norm
        self.dq8 = self.q8 * np.sqrt((self.dqs / self.qs)**2 + (self.dqs_norm / self.qs_norm)**2)

    @staticmethod
    def _transport_law(x, c, xc):
        """Return q* = c * (x - xc)**(3/2)."""
        return c * (x - xc)**(3/2)

    @staticmethod
    def _inverse_transport_law(q, c, xc):
        """Return x such that c * (x - xc)**(3/2) == q."""
        return xc + (q / c)**(2. / 3.)

    def _fit_transport(self, x, dx, xc_value, xc_min):
        """Fit q8 = c * (x - xc)**(3/2) and store data, model and parameters on ``self``.

        Non-finite ``x`` values are excluded from the fit, and each point is
        weighted by ``1/dq8``. ``xc`` is bounded above by ``nanmin(x)`` so that
        ``x - xc`` stays non-negative on the data.
        Sets ``x_data``/``dx_data``/``y_data``/``dy_data``, their ``*_plot``
        copies, ``model``, ``params``, ``results``, ``c``/``dc`` and ``xc``/``dxc``.
        """
        self.x_data = x
        self.dx_data = dx
        self.y_data = self.q8
        self.dy_data = self.dq8

        self.model = lm.Model(self._transport_law)
        self.model.set_param_hint('c', value=0.01, min=0)
        self.model.set_param_hint('xc', value=xc_value, min=xc_min, max=np.nanmin(x))
        self.params = self.model.make_params()
        self.dx_plot = self.dx_data.copy()
        self.dy_plot = self.dy_data.copy()
        self.x_plot = self.x_data.copy()
        self.y_plot = self.y_data.copy()

        ok = np.isfinite(x)
        self.results = self.model.fit(self.y_data[ok], self.params, x=x[ok], weights=1 / self.dy_data[ok])
        fitted = self.results.result.params
        self.xc = fitted['xc'].value
        self.dxc = fitted['xc'].stderr
        self.c = fitted['c'].value
        self.dc = fitted['c'].stderr

    def fit_model_standard(self):
        """Fit q* against the conventional Shields number tau*.

        Also stores ``rmsd_standard``, the RMSD of log10(q*) between data and fit.
        Non-finite points are ignored, consistent with the fit.
        """
        self._fit_transport(self.tau8, self.dtau8, xc_value=0.1, xc_min=0)
        predicted = self._transport_law(self.x_data, self.c, self.xc)
        self.rmsd_standard = np.sqrt(np.nanmean((np.log10(self.y_data) - np.log10(predicted))**2))

    def fit_model_CDx(self):
        """Fit q* against the drag-corrected Shields number C* tau*."""
        self._fit_transport(self.tau8_CDx, self.dtau8_CDx, xc_value=.05, xc_min=.01)

    def fit_model_mus(self):
        """Fit q* against the friction-corrected Shields number tau*/mu*."""
        self._fit_transport(self.tau8_mus, self.dtau8_mus, xc_value=0.05, xc_min=0.01)

    def _plot_fit(self, ax, xlabel, xlim=None, ylim=None):
        """Draw the stored data, the fitted curve and its uncertainty band on log-log axes.

        The curve is evaluated as x(q) for q in [1e-4, 10]. The band uses
        parameter offsets of 1.96 * stderr; a missing stderr yields no band.
        ``xlim``/``ylim`` are argument tuples for ``set_xlim``/``set_ylim`` and
        are applied before both axes are switched to log scale.
        """
        dc = np.nan if self.dc is None else self.dc
        dxc = np.nan if self.dxc is None else self.dxc
        q = np.logspace(-4, 1, 1000)
        fit = self._inverse_transport_law(q, self.c, self.xc)
        fit_up = self._inverse_transport_law(q, self.c - dc * 1.96, self.xc + dxc * 1.96)
        fit_dw = self._inverse_transport_law(q, self.c + dc * 1.96, self.xc - dxc * 1.96)

        ax.errorbar(self.x_plot, self.y_plot, xerr=self.dx_plot, yerr=self.dy_plot,
                    ecolor=sns.xkcd_rgb['grey'], elinewidth=.5, capthick=.5, capsize=1,
                    color=self.color, markeredgecolor='w', markeredgewidth=1,
                    ms=size[self.name], fmt=marker[self.name], label=label[self.name])
        ax.plot(fit, q, self.color, ls='-', lw=1, alpha=1)
        ax.fill_betweenx(q, fit_up, fit_dw, alpha=.05, color=self.color, linewidth=0.0)
        fontsize = 22
        if xlim is not None:
            ax.set_xlim(*xlim)
        if ylim is not None:
            ax.set_ylim(*ylim)
        ax.set_xlabel(xlabel, fontsize=fontsize)
        ax.set_ylabel(r'Nondimensional flux, $q^*$', fontsize=fontsize)
        ax.set_xscale('log')
        ax.set_yscale('log')

    def plot_fit_mus(self, ax=None, errorbars=True, offset=True):
        """Fit and plot q* against tau*/mu*. ``errorbars``/``offset`` are unused."""
        if ax is None:
            _, ax = plt.subplots()
        self.fit_model_mus()
        self._plot_fit(ax, r'Friction-corrected Shields number, $\tau^*/\mu^*$')

    def plot_fit_standard(self, ax=None, errorbars=True, offset=True):
        """Fit and plot q* against tau*. ``errorbars``/``offset`` are unused.

        Also stores ``co``/``xco`` (c and xc rescaled by mu*/CDx) and prints
        the grain name and fitted parameters.
        """
        if ax is None:
            _, ax = plt.subplots()
        self.fit_model_standard()
        self.co = self.c * (self.mustar / self.CDx)**1.5
        self.xco = self.xc / (self.mustar / self.CDx)
        print(self.name, self.c, self.xc)
        self._plot_fit(ax, r'Conventional Shields number, $\tau^*$', xlim=(.015, .2), ylim=(.0005,))

    def plot_fit_CDx(self, ax=None, errorbars=True, offset=True):
        """Fit and plot q* against C* tau*. ``errorbars``/``offset`` are unused."""
        if ax is None:
            _, ax = plt.subplots()
        self.fit_model_CDx()
        self._plot_fit(ax, r'Drag-corrected Shields number, $C^*\tau^*$')

    def plot_fit_final_new(self, F1_in, alpha, tauc, ax=None, errorbars=True, offset=True):
        """Plot q* against tau* scaled by ``F1_in`` on log-log axes, without fitting.

        ``alpha``, ``tauc``, ``errorbars`` and ``offset`` are accepted for API
        compatibility but are unused. No legend label is attached.
        """
        if ax is None:
            _, ax = plt.subplots()

        ax.errorbar(F1_in * self.tau8, self.q8, xerr=F1_in * self.dtau8, yerr=self.dq8,
                    ecolor=sns.xkcd_rgb['grey'], elinewidth=.5, capthick=.5, capsize=1,
                    color=self.color, markeredgecolor='w', markeredgewidth=1,
                    ms=size[self.name], fmt=marker[self.name])

        fontsize = 22
        ax.set_xlabel(r'Shape-corrected Shields number, $(C^*/\mu^{*})\tau^*$', fontsize=fontsize)
        ax.set_ylabel(r'Nondimensional flux, $q^*$', fontsize=fontsize)
        ax.set_xscale('log')
        ax.set_yscale('log')


class Fits:
    """Load all grain data, build one ``Grain`` per kind and draw the comparison figures.

    The ``plot_standard``/``plot_CDx``/``plot_mus``/``plot_final_new`` methods
    overlay per-grain data and fits on one axis. ``plot_coeff``/``plot_taucs``
    and their ``_2`` variants plot fitted per-grain coefficients against mu* or
    C*. With ``plot=False`` they instead store a fixed-exponent power-law
    regression of those coefficients on ``self``.
    """

    def __init__(self, grains=None, cdxflag=0, fixed_params=(np.nan, np.nan)):
        """Read the grain-property tables and experiment metadata, then build the grains.

        Parameters
        ----------
        grains : sequence of str or None
            Grain-kind keys plotted by default. None means
            ``['oc', 'gb', 'gg', 'ls', 'ng']``. A copy is stored, so the
            caller's list is never aliased.
        cdxflag : int
            Stored as ``self.cdxflag``; not read by the methods in this class.
        fixed_params : tuple
            Forwarded to every ``Grain``.
        """
        self.grains = ['oc', 'gb', 'gg', 'ls', 'ng'] if grains is None else list(grains)
        self.cdxflag = cdxflag
        self.mean_bed_angle = 3.5
        self.ar_glass = 24.
        self.df_density = self._read_props('1_grain_properties/1_var_density.csv')
        self.df_shape = self._read_props('1_grain_properties/2_var_shape.csv')
        self.df_ARS = self._read_props('1_grain_properties/3_var_ARS.csv')
        self.df_Meas_SV = self._read_props('1_grain_properties/4_var_settling_velocities.csv')
        self.df_VES_SV = self._read_props('1_grain_properties/5_var_VES_settling_velocity.csv')
        self.df = pd.read_csv('3_exp_meta_data/exp_data_6_2022.txt').drop('Unnamed: 0', axis=1).set_index('names')
        # Independent copy of the same table (read once instead of twice).
        self.df_all_data = self.df.copy()
        self.y_data_lim = 2e-2
        self.fixed_params = fixed_params

        self.xf = {}
        for name, group in self.df.groupby('grain_kind'):
            self.xf[name] = Grain(name, self.df_ARS.loc[name], self.df_Meas_SV.loc[name],
                                  self.df_VES_SV.loc[name], self.df_density.loc[name],
                                  self.df_shape.loc[name], group.sort_values(by='mass_flux'),
                                  self.y_data_lim, fixed_params=self.fixed_params)

    @staticmethod
    def _read_props(path):
        """Read a grain-property CSV indexed by its unnamed first column."""
        return pd.read_csv(path).set_index('Unnamed: 0').rename_axis(None)

    @staticmethod
    def _power_law(x, m, b):
        """Return m * x**b."""
        return m * x**b

    @staticmethod
    def _transport_law(x, c, xc):
        """Return c * (x - xc)**(3/2)."""
        return c * (x - xc)**(3/2)

    def _refit(self, grains, method):
        """Call the ``Grain`` method named ``method`` on each grain; return the grains in order."""
        fitted = [self.xf[key] for key in grains]
        for grain in fitted:
            getattr(grain, method)()
        return fitted

    def _fit_fixed_exponent(self, x, y, dy, b, m0):
        """Fit y = m * x**b with ``b`` held fixed and return the fitted lmfit parameters."""
        model = lm.Model(self._power_law)
        model.set_param_hint('b', value=b, vary=False)
        model.set_param_hint('m', value=m0)
        params = model.make_params()
        # NOTE(review): weights=y/dy scale residuals by y/sigma rather than the 1/sigma
        # used for the per-grain fits; kept as-is. Confirm this is intended.
        return model.fit(y, params, weights=y / dy, x=x).result.params

    @staticmethod
    def _summary_errorbar(ax, key, x, y, xerr, yerr):
        """Draw one grain's coefficient point with the shared summary-plot styling."""
        ax.errorbar(x, y, xerr=xerr, yerr=yerr,
                    ecolor=sns.xkcd_rgb['dark grey'], elinewidth=1, capthick=.8, capsize=3,
                    color=color[key], markeredgecolor='w', markeredgewidth=1,
                    ms=size[key] - 2, fmt=marker[key])

    def plot_standard(self, grains=None, ax=None):
        """Overlay every grain's fit against the conventional Shields number."""
        if not grains:
            grains = self.grains
        fontsize = 18
        if ax is None:
            _, ax = plt.subplots(figsize=(8, 6))
        for name in grains:
            self.xf[name].plot_fit_standard(ax=ax, offset=False)
        ax.legend(loc=4, fontsize=fontsize)
        ax.tick_params(labelsize=fontsize)
        ax.set_ylim(1e-3, 2)
        ax.set_xlim(.04, 0.3)

    def plot_CDx(self, grains=None, ax=None):
        """Overlay every grain's fit against the drag-corrected Shields number."""
        if not grains:
            grains = self.grains
        fontsize = 18
        if ax is None:
            _, ax = plt.subplots(figsize=(8, 6))
        for name in grains:
            self.xf[name].plot_fit_CDx(ax=ax, offset=False)
        ax.tick_params(labelsize=fontsize)
        ax.set_ylim(1e-3, 2)
        ax.set_xlim(.02, 0.5)

    def plot_final_new(self, grains=None, ax=None):
        """Plot all grains against the shape-corrected Shields number (C*/mu*) tau*.

        Also refreshes the stored alpha/xc regressions (``plot_coeff`` and
        ``plot_taucs`` with ``plot=False``) and stores two log-space RMSDs:
        ``rmsd_final`` for the shape-corrected law with fixed alpha_o/tau_o,
        and ``rmsd_standard_all`` for the uncorrected law using c and xc
        averaged over the same ``grains``.
        """
        if not grains:
            grains = self.grains
        fontsize = 18
        if ax is None:
            _, ax = plt.subplots(figsize=(8, 6))

        self.plot_coeff(plot=False)
        self.plot_taucs(plot=False)

        alpha_o = 11.991
        tau_o = 0.0506
        law = self._transport_law

        # Shape-correction factor C*/mu* for each grain.
        corr = {key: self.xf[key].CDx / self.xf[key].mustar for key in grains}
        y_data = np.hstack([self.xf[key].q8 for key in grains])
        x_data = np.hstack([self.xf[key].tau8 for key in grains])
        corrected = np.hstack([corr[key] * self.xf[key].tau8 for key in grains])
        self.rmsd_final = np.sqrt(np.nanmean((np.log10(y_data) - np.log10(law(corrected, alpha_o, tau_o)))**2))

        for key in grains:
            self.xf[key].plot_fit_final_new(corr[key], alpha=alpha_o, tauc=tau_o, ax=ax)

        x = np.logspace(-2, 0, 1000)
        ax.plot(x, law(x, alpha_o, tau_o), 'k--', label='Shape-corrected transport law')

        finite = np.isfinite(x_data) & np.isfinite(y_data)
        x_data, y_data = x_data[finite], y_data[finite]

        # Uncorrected law with parameters averaged over the same grains as the data.
        standard = self._refit(grains, 'fit_model_standard')
        c = np.mean([gr.c for gr in standard])
        xc = np.mean([gr.xc for gr in standard])
        self.rmsd_standard_all = np.sqrt(np.nanmean((np.log10(y_data) - np.log10(law(x_data, c, xc)))**2))

        ax.legend(loc=4, fontsize=fontsize - 2)
        ax.tick_params(labelsize=fontsize)
        ax.set_ylim(1e-3, 2)
        ax.set_xlim(0.04, 0.3)

    def plot_mus(self, grains=None, ax=None):
        """Overlay every grain's fit against the friction-corrected Shields number."""
        if not grains:
            grains = self.grains
        fontsize = 18
        if ax is None:
            _, ax = plt.subplots(figsize=(8, 6))
        for name in grains:
            self.xf[name].plot_fit_mus(ax=ax, offset=False)
        ax.tick_params(labelsize=fontsize)
        ax.set_ylim(1e-3, 2)
        ax.set_xlim(.02, 0.5)

    def plot_coeff(self, grains=None, ax=None, plot=True):
        """Relate the drag-corrected prefactor to friction; every grain is refit with ``fit_model_CDx``.

        ``plot=True``: plot each grain's ``c`` against its mu* (x error = the
        grain's propagated ``dmustar``), together with the reference curve
        11.99 * mu***(-3/2).
        Otherwise: fit c = m * tan(angle)**b with b = -1.5 fixed, and store
        ``alpha_m``, ``alpha_dm``, ``alpha_b``, ``alpha_db`` and ``alpha_func``.
        """
        if not grains:
            grains = self.grains
        fontsize = 18

        if plot is True:
            if ax is None:
                _, ax = plt.subplots(figsize=(8, 6))
            ax.tick_params(labelsize=fontsize - 2)
            fitted = self._refit(grains, 'fit_model_CDx')
            ax.set_ylabel(r'$\alpha/C^{*3/2}$', fontsize=fontsize + 2)

            x = np.linspace(0.8, 2.5, 100)
            ax.plot(x, self._power_law(x, 11.99, -1.5), 'k-', label=r'$\alpha/C^{*3/2} = \alpha_o/\mu^{\ast 3/2}$')
            for key, grain in zip(grains, fitted):
                self._summary_errorbar(ax, key, grain.mustar, grain.c, grain.dmustar, grain.dc)

            ax.legend(fontsize=fontsize - 2, loc=3)
            ax.ticklabel_format(useOffset=False, style='plain')
            ax.set_xlim(0, 3)
            ax.set_ylim(0, 15)
        else:
            fitted = self._refit(grains, 'fit_model_CDx')
            angle = np.array([gr.ar for gr in fitted])
            a = np.array([gr.c for gr in fitted])
            da = np.array([gr.dc for gr in fitted])
            # NOTE(review): regresses on tan(angle), not on mu* as plotted above. Confirm.
            p = self._fit_fixed_exponent(np.tan(np.radians(angle)), a, da, b=-1.5, m0=1.0)
            self.alpha_b = p['b'].value
            self.alpha_db = p['b'].stderr
            self.alpha_m = p['m'].value
            self.alpha_dm = p['m'].stderr
            power = self._power_law
            self.alpha_func = lambda x, m, b: power(np.tan(np.radians(x)), m, b)

    def plot_taucs(self, grains=None, ax=None, plot=True):
        """Relate the drag-corrected threshold to friction; every grain is refit with ``fit_model_CDx``.

        ``plot=True``: plot each grain's ``xc`` against its mu* (x error = the
        grain's propagated ``dmustar``), together with the reference line 0.0506 * mu*.
        Otherwise: fit xc = m * tan(angle)**b with b = 1 fixed, and store
        ``xc_m``, ``xc_dm``, ``xc_b``, ``xc_db`` and ``xc_func``.
        """
        if not grains:
            grains = self.grains
        fontsize = 18

        if plot is True:
            if ax is None:
                _, ax = plt.subplots(figsize=(8, 6))
            ax.tick_params(labelsize=fontsize - 2)
            fitted = self._refit(grains, 'fit_model_CDx')
            ax.set_xlabel(r'Norm. coeff. of friction, $\mu^*$', fontsize=fontsize + 2, labelpad=12)
            ax.set_ylabel(r'$C^*\tau_c^*$', fontsize=fontsize + 2)

            x = np.linspace(0.5, 2.5, 100)
            ax.plot(x, self._power_law(x, 0.0506, 1), 'k-', label=r'$C^*\tau_c^* = \tau_{co}^*\mu^\ast$')
            for key, grain in zip(grains, fitted):
                self._summary_errorbar(ax, key, grain.mustar, grain.xc, grain.dmustar, grain.dxc)

            ax.legend(fontsize=fontsize - 1, loc=4)
            ax.set_xlim(0, 3)
            ax.set_ylim(0, .15)
        else:
            fitted = self._refit(grains, 'fit_model_CDx')
            angle = np.array([gr.ar for gr in fitted])
            a = np.array([gr.xc for gr in fitted])
            da = np.array([gr.dxc for gr in fitted])
            # NOTE(review): regresses on tan(angle), not on mu* as plotted above. Confirm.
            p = self._fit_fixed_exponent(np.tan(np.radians(angle)), a, da, b=1, m0=0.06)
            self.xc_b = p['b'].value
            self.xc_db = p['b'].stderr
            self.xc_m = p['m'].value
            self.xc_dm = p['m'].stderr
            power = self._power_law
            self.xc_func = lambda x, m, b: power(np.tan(np.radians(x)), m, b)

    def plot_coeff_2(self, grains=None, ax=None, plot=True):
        """Relate the friction-corrected prefactor to drag; every grain is refit with ``fit_model_mus``.

        ``plot=True``: plot each grain's ``c`` against C* (``CDx`` +/- ``dCDx``),
        together with the reference curve 11.99 * C***(3/2).
        Otherwise: fit c = m * C***b with b = 1.5 fixed, matching the plotted
        relation, and store ``alpha_m``, ``alpha_dm``, ``alpha_b``, ``alpha_db``
        and ``alpha_func``.
        """
        if not grains:
            grains = self.grains
        fontsize = 18

        if plot is True:
            if ax is None:
                _, ax = plt.subplots(figsize=(8, 6))
            ax.tick_params(labelsize=fontsize - 2)
            fitted = self._refit(grains, 'fit_model_mus')
            ax.set_ylabel(r'$\alpha\mu^{*3/2}$', fontsize=fontsize + 2)

            x = np.linspace(0.65, 2.5, 100)
            ax.plot(x, self._power_law(x, 11.99, 1.5), 'k-', label=r'$\alpha\mu^{*3/2} = \alpha_oC^{\ast 3/2}$')
            for key, grain in zip(grains, fitted):
                self._summary_errorbar(ax, key, grain.CDx, grain.c, grain.dCDx, grain.dc)

            ax.legend(fontsize=fontsize - 1, loc=4)
            ax.ticklabel_format(useOffset=False, style='plain')
            ax.set_xlim(0, 3)
            ax.set_ylim(0, 50)
        else:
            # The previous version appended to an undefined list, read the misspelt attribute
            # 'CDX', and fitted with exponent -1.5, which contradicts the plotted relation.
            fitted = self._refit(grains, 'fit_model_mus')
            C = np.array([gr.CDx for gr in fitted])
            a = np.array([gr.c for gr in fitted])
            da = np.array([gr.dc for gr in fitted])
            p = self._fit_fixed_exponent(C, a, da, b=1.5, m0=1.0)
            self.alpha_b = p['b'].value
            self.alpha_db = p['b'].stderr
            self.alpha_m = p['m'].value
            self.alpha_dm = p['m'].stderr
            self.alpha_func = self._power_law

    def plot_taucs_2(self, grains=None, ax=None, plot=True):
        """Relate the friction-corrected threshold to drag; every grain is refit with ``fit_model_mus``.

        ``plot=True``: plot each grain's ``xc`` against C* (``CDx`` +/- ``dCDx``),
        together with the reference curve 0.0506 / C*.
        Otherwise: fit xc = m * C***b with b = -1 fixed, matching the plotted
        relation, and store ``xc_m``, ``xc_dm``, ``xc_b``, ``xc_db`` and ``xc_func``.
        """
        if not grains:
            grains = self.grains
        fontsize = 18

        if plot is True:
            if ax is None:
                _, ax = plt.subplots(figsize=(8, 6))
            ax.tick_params(labelsize=fontsize - 2)
            fitted = self._refit(grains, 'fit_model_mus')
            ax.set_xlabel(r'Norm. drag coeff., $C^*$', fontsize=fontsize + 2, labelpad=12)
            ax.set_ylabel(r'$\tau_c^*/\mu^*$', fontsize=fontsize + 2)

            x = np.linspace(0.8, 2.5, 100)
            ax.plot(x, self._power_law(x, 0.0506, -1), 'k-', label=r'$\tau_c^*/\mu^* = \tau_{co}^*/C^\ast$')
            for key, grain in zip(grains, fitted):
                self._summary_errorbar(ax, key, grain.CDx, grain.xc, grain.dCDx, grain.dxc)

            ax.legend(fontsize=fontsize - 1, loc=3)
            ax.set_xlim(0, 3)
            ax.set_ylim(0, .08)
        else:
            fitted = self._refit(grains, 'fit_model_mus')
            C = np.array([gr.CDx for gr in fitted])
            a = np.array([gr.xc for gr in fitted])
            da = np.array([gr.dxc for gr in fitted])
            # Exponent -1 matches the plotted relation tau_c*/mu* = tau_co*/C*.
            p = self._fit_fixed_exponent(C, a, da, b=-1, m0=0.06)
            self.xc_b = p['b'].value
            self.xc_db = p['b'].stderr
            self.xc_m = p['m'].value
            self.xc_dm = p['m'].stderr
            self.xc_func = self._power_law
